{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "# from alphagen.config import *\n",
    "# from alphagen.data.tokens import *\n",
    "from alphagen_generic.features import *\n",
    "from alphagen.data.expression import *\n",
    "from gan.utils.data import get_data_by_year"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "instruments: str = \"csi300\"\n",
    "from typing import Tuple\n",
    "import json\n",
    "\n",
    "def load_alpha_pool(raw) -> Tuple[List[Expression], List[float]]:\n",
    "    \"\"\"Rebuild the expressions and weights stored in a decoded pool dict.\n",
    "\n",
    "    Every string under ``raw['exprs']`` has ``open`` rewritten to ``open_`` and all\n",
    "    ``$`` characters removed, and is then evaluated as Python source.\n",
    "\n",
    "    Returns:\n",
    "        ``(exprs, weights)`` where ``weights`` is ``raw['weights']`` returned as-is.\n",
    "    \"\"\"\n",
    "    sources = [text.replace('open', 'open_').replace('$', '') for text in raw['exprs']]\n",
    "    # NOTE(review): eval() executes whatever the pool file contains - only load trusted files.\n",
    "    parsed = [eval(source) for source in sources]\n",
    "    return parsed, raw['weights']\n",
    "\n",
    "def load_alpha_pool_by_path(path: str) -> Tuple[List[Expression], List[float]]:\n",
    "    \"\"\"Read the UTF-8 JSON file at ``path`` and decode it with ``load_alpha_pool``.\"\"\"\n",
    "    with open(path, encoding='utf-8') as pool_file:\n",
    "        pool_dict = json.load(pool_file)\n",
    "    # The JSON is fully parsed at this point, so decoding can run after the file is closed.\n",
    "    return load_alpha_pool(pool_dict)\n",
    "    \n",
    "import os\n",
    "def load_ppo_path(path,name_prefix):\n",
    "    \"\"\"Return the path of the highest-numbered ``.json`` file of a matching run folder.\n",
    "\n",
    "    Args:\n",
    "        path: Directory whose entries are searched.\n",
    "        name_prefix: Substring an entry name must contain to be considered.\n",
    "\n",
    "    Returns:\n",
    "        ``f\"{path}/{folder}/{name}\"`` where ``folder`` is the last matching entry in\n",
    "        sorted order and ``name`` is the ``.json`` file whose leading integer\n",
    "        (the text before the first ``_``) is largest.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If no entry matches ``name_prefix`` or the folder has no ``.json`` file.\n",
    "        ValueError: If a ``.json`` file name does not start with an integer.\n",
    "    \"\"\"\n",
    "    # os.listdir() yields entries in arbitrary order, so sort before taking the last match;\n",
    "    # otherwise the selected folder can differ between machines and runs.\n",
    "    # NOTE(review): sorted order equals \"newest\" only if folder names sort chronologically - confirm.\n",
    "    folders = sorted(entry for entry in os.listdir(path) if name_prefix in entry)\n",
    "    if not folders:\n",
    "        raise IndexError(f\"no entry containing {name_prefix!r} found in {path!r}\")\n",
    "    folder = folders[-1]\n",
    "\n",
    "    # endswith() rather than a substring test, so names such as 'x.json.bak' are ignored.\n",
    "    names = [entry for entry in os.listdir(f\"{path}/{folder}\") if entry.endswith('.json')]\n",
    "    if not names:\n",
    "        raise IndexError(f\"no .json file found in {path}/{folder}\")\n",
    "    # Rank by the leading integer; the name itself breaks ties deterministically.\n",
    "    name = max(names, key=lambda entry: (int(entry.split('_')[0]), entry))\n",
    "    return f\"{path}/{folder}/{name}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference: compute and save each pool's ensemble alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from alphagen_qlib.calculator import QLibStockDataCalculator\n",
    "\n",
    "name = 'test1'\n",
    "freq = 'day'\n",
    "\n",
    "# For every saved pool, compute its ensemble alpha and store the test-year slice next to pool.json.\n",
    "# NOTE(review): `target` is not assigned in this notebook; presumably it comes from the star imports - verify.\n",
    "for instruments in ['csi300','csi500']:\n",
    "    for train_end in range(2016,2021):\n",
    "        # The data split depends only on (instruments, train_end), so it is loaded once here\n",
    "        # instead of once per (num, seed) pair.\n",
    "        returned = get_data_by_year(\n",
    "            train_start = 2010,train_end=train_end,valid_year=train_end+1,test_year =train_end+2,\n",
    "            instruments=instruments, target=target,freq=freq,\n",
    "        )\n",
    "        data_all, data,data_valid,data_valid_withhead,data_test,data_test_withhead,_ = returned\n",
    "\n",
    "        # Alphas are evaluated on data_all (not data_test) and sliced to the test year afterwards.\n",
    "        calculator_test = QLibStockDataCalculator(data_all, target)\n",
    "\n",
    "        for num in [1,10,20,50,100]:\n",
    "            for seed in range(5):\n",
    "                dirname = f\"out_dso/{name}_{instruments}_{num}_{train_end}_{seed}\"\n",
    "                path = f\"{dirname}/pool.json\"\n",
    "                print(path)\n",
    "                exprs,weights = load_alpha_pool_by_path(path)\n",
    "\n",
    "                ensemble_value = calculator_test.make_ensemble_alpha(exprs, weights)\n",
    "                # Keep the trailing data_test.n_days rows; assumes data_all ends where data_test ends - TODO confirm.\n",
    "                ensemble_value = ensemble_value[-data_test.n_days:]\n",
    "\n",
    "                torch.save(ensemble_value.cpu(),f\"{dirname}/{train_end}_{num}_{seed}.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read Experiment Results and Calculate Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from alphagen.utils.correlation import batch_pearsonr,batch_spearmanr\n",
    "import pandas as pd\n",
    "\n",
    "device = 'cuda:0'\n",
    "name = 'test1'\n",
    "instruments = 'csi300'\n",
    "freq = 'day'  # set here as well so this cell does not depend on the inference cell having run\n",
    "\n",
    "# The target slice depends only on train_end, so it is evaluated once per year instead of\n",
    "# once per (seed, num, train_end) combination.\n",
    "# NOTE(review): `target` is not assigned in this notebook; presumably it comes from the star imports - verify.\n",
    "test_targets = {}\n",
    "for train_end in range(2016,2021):\n",
    "    returned = get_data_by_year(\n",
    "        train_start = 2010,train_end=train_end,valid_year=train_end+1,test_year =train_end+2,\n",
    "        instruments=instruments, target=target,freq=freq,\n",
    "    )\n",
    "    data_all, data,data_valid,data_valid_withhead,data_test,data_test_withhead,_ = returned\n",
    "    # Evaluate on data_all and keep the trailing data_test.n_days rows, the same slice\n",
    "    # that was applied to the saved predictions.\n",
    "    test_targets[train_end] = target.evaluate(data_all)[-data_test.n_days:,:]\n",
    "\n",
    "result = []\n",
    "for seed in range(5):\n",
    "    for num in [1,10,20,50,100]:\n",
    "        # Start fresh for every (seed, num). The accumulators used to live outside the num loop,\n",
    "        # so each row also contained the ICs of every previously processed num for that seed.\n",
    "        cur_ic = []\n",
    "        cur_ric = []\n",
    "        for train_end in range(2016,2021):\n",
    "            dirname = f\"out_dso/{name}_{instruments}_{num}_{train_end}_{seed}\"\n",
    "\n",
    "            pred = torch.load(f\"{dirname}/{train_end}_{num}_{seed}.pt\").to(device)\n",
    "            tgt = test_targets[train_end]\n",
    "\n",
    "            # NaN correlations are counted as 0.\n",
    "            cur_ic.append(torch.nan_to_num(batch_pearsonr(pred,tgt),nan=0))\n",
    "            cur_ric.append(torch.nan_to_num(batch_spearmanr(pred,tgt),nan=0))\n",
    "\n",
    "        # Pool the values of all test years for this (seed, num) before summarising.\n",
    "        ic = torch.cat(cur_ic)\n",
    "        rank_ic = torch.cat(cur_ric)\n",
    "\n",
    "        ic_mean = ic.mean().item()\n",
    "        rank_ic_mean = rank_ic.mean().item()\n",
    "        ic_std = ic.std().item()\n",
    "        rank_ic_std = rank_ic.std().item()\n",
    "        result.append(dict(\n",
    "            seed = seed,\n",
    "            num = num,\n",
    "            ic = ic_mean,\n",
    "            ric = rank_ic_mean,\n",
    "            icir = ic_mean/ic_std,\n",
    "            ricir = rank_ic_mean/rank_ic_std,\n",
    "        ))\n",
    "\n",
    "# One row per num: mean and std of each metric across seeds.\n",
    "exp_result = pd.DataFrame(result).groupby(['num','seed']).mean().groupby('num').agg(['mean','std'])\n",
    "print(exp_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38n1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
